# ===----------------------------------------------------------------------=== #
# Copyright (c) 2025, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from gpu.host import DeviceContext
from layout import Layout, LayoutTensor, RuntimeLayout, UNKNOWN_VALUE
from layout._fillers import random
from memory import LegacyUnsafePointer as UnsafePointer
from nn.conv import conv_cudnn, conv_gpu
from testing import assert_almost_equal

from utils.index import IndexList


# input: NHWC
# filter: RSCF
fn test_conv_cudnn[
    input_dim: IndexList[4],
    filter_dim: IndexList[4],
    output_dim: IndexList[4],
    input_type: DType,
    filter_type: DType,
    output_type: DType,
    stride_dim: IndexList[2],
    dilation_dim: IndexList[2],
    pad_dim: IndexList[2],
    num_groups: Int = 1,
](ctx: DeviceContext) raises:
    # Cross-checks `conv_cudnn` against `conv_gpu` on identical random data.
    #
    # Both convolutions receive the same input and the same filter values
    # (the filter is re-ordered for `conv_cudnn`); their outputs are then
    # compared element by element with `assert_almost_equal(rtol=0.01)`.
    #
    # Parameters:
    #   input_dim:    input shape, read below as (N, H, W, C_in).
    #   filter_dim:   filter shape, read below as (R, S, C, F).
    #   output_dim:   output shape, read below as (Nout, Hout, Wout, Cout).
    #                 It is supplied by the caller; nothing in this function
    #                 derives or checks it against the other parameters.
    #   input_type, filter_type, output_type:
    #                 element dtypes of the input, filter and output tensors.
    #   stride_dim, dilation_dim, pad_dim, num_groups:
    #                 forwarded unchanged to both convolution calls.
    #
    # Args:
    #   ctx: device context used for buffer creation, copies, both
    #        convolution calls and the final synchronize.
    #
    # Raises:
    #   Presumably when a compared element pair is outside tolerance (via
    #   `assert_almost_equal`), plus anything the device calls raise.
    print(
        "== test_cudnn_conv_gpu: dtype_in=",
        input_type,
        " dtype_filter=",
        filter_type,
        " dtype_out=",
        output_type,
        " num_groups=",
        num_groups,
    )

    # Extract dimensions
    comptime N = input_dim[0]
    comptime H = input_dim[1]
    comptime W = input_dim[2]
    comptime C_in = input_dim[3]

    comptime R = filter_dim[0]
    comptime S = filter_dim[1]
    comptime C = filter_dim[2]
    comptime F = filter_dim[3]

    comptime Nout = output_dim[0]
    comptime Hout = output_dim[1]
    comptime Wout = output_dim[2]
    comptime Cout = output_dim[3]

    # Define layouts
    # The filter gets two row-major layouts over the same number of elements:
    # (R, S, C, F) for `conv_gpu` and (F, C, R, S) for `conv_cudnn`.
    comptime input_layout = Layout.row_major(N, H, W, C_in)
    comptime filter_layout = Layout.row_major(R, S, C, F)
    comptime filter_nchw_layout = Layout.row_major(F, C, R, S)
    comptime output_layout = Layout.row_major(Nout, Hout, Wout, Cout)

    # Element counts used for every host allocation and device buffer below.
    comptime input_dim_flattened = N * H * W * C_in
    comptime filter_dim_flattened = R * S * C * F
    comptime output_dim_flattened = Nout * Hout * Wout * Cout

    # Allocate host memory
    var input_host_ptr = UnsafePointer[Scalar[input_type]].alloc(
        input_dim_flattened
    )
    var filter_host_ptr = UnsafePointer[Scalar[filter_type]].alloc(
        filter_dim_flattened
    )
    var filter_nchw_host_ptr = UnsafePointer[Scalar[filter_type]].alloc(
        filter_dim_flattened
    )
    var output_ref_host_ptr = UnsafePointer[Scalar[output_type]].alloc(
        output_dim_flattened
    )
    var output_host_ptr = UnsafePointer[Scalar[output_type]].alloc(
        output_dim_flattened
    )

    # Create host LayoutTensors
    # These are views over the raw host allocations above.
    var input_host = LayoutTensor[input_type, input_layout](input_host_ptr)
    var filter_host = LayoutTensor[filter_type, filter_layout](filter_host_ptr)
    var filter_nchw_host = LayoutTensor[filter_type, filter_nchw_layout](
        filter_nchw_host_ptr
    )
    var output_ref_host = LayoutTensor[output_type, output_layout](
        output_ref_host_ptr
    )
    var output_host = LayoutTensor[output_type, output_layout](output_host_ptr)

    # Fill the input and the RSCF filter with random values.
    random(input_host)
    random(filter_host)

    # Re-order the filter from (R, S, C, F) into the (F, C, R, S) copy that is
    # later handed to `conv_cudnn`, so both convolutions see the same weights.
    for r in range(R):
        for s in range(S):
            for c in range(C):
                for f in range(F):
                    filter_nchw_host[f, c, r, s] = filter_host[r, s, c, f]

    # Zero both host result buffers.
    # NOTE(review): these zeros are never copied to the device; the host
    # buffers are only written again by the copies back after the
    # convolutions, so the device output buffers start uninitialized.
    _ = output_host.fill(0)
    _ = output_ref_host.fill(0)

    # Allocate device buffers
    var input_dev = ctx.enqueue_create_buffer[input_type](input_dim_flattened)
    var filter_dev = ctx.enqueue_create_buffer[filter_type](
        filter_dim_flattened
    )
    var filter_nchw_dev = ctx.enqueue_create_buffer[filter_type](
        filter_dim_flattened
    )
    var output_dev = ctx.enqueue_create_buffer[output_type](
        output_dim_flattened
    )
    var output_ref_dev = ctx.enqueue_create_buffer[output_type](
        output_dim_flattened
    )

    # Create device LayoutTensors
    # Each tensor is built from the buffer's raw `unsafe_ptr()`.
    var input_dev_tensor = LayoutTensor[input_type, input_layout](
        input_dev.unsafe_ptr()
    )
    var filter_dev_tensor = LayoutTensor[filter_type, filter_layout](
        filter_dev.unsafe_ptr()
    )
    var filter_nchw_dev_tensor = LayoutTensor[filter_type, filter_nchw_layout](
        filter_nchw_dev.unsafe_ptr()
    )
    var output_dev_tensor = LayoutTensor[output_type, output_layout](
        output_dev.unsafe_ptr()
    )
    var output_ref_dev_tensor = LayoutTensor[output_type, output_layout](
        output_ref_dev.unsafe_ptr()
    )

    # Upload the input and both filter orderings (presumably dst, src order).
    ctx.enqueue_copy(input_dev, input_host_ptr)
    ctx.enqueue_copy(filter_dev, filter_host_ptr)
    ctx.enqueue_copy(filter_nchw_dev, filter_nchw_host_ptr)

    # Reference result: `conv_gpu` with the RSCF filter, written to
    # `output_ref_dev`.
    conv_gpu[
        input_layout,
        filter_layout,
        output_layout,
        input_type,
        filter_type,
        output_type,
    ](
        input_dev_tensor.as_any_origin(),
        filter_dev_tensor.as_any_origin(),
        output_ref_dev_tensor.as_any_origin(),
        stride_dim,
        dilation_dim,
        pad_dim,
        num_groups,
        ctx,
    )

    # Result under test: `conv_cudnn` with the (F, C, R, S) filter, written to
    # `output_dev`.
    conv_cudnn[input_type, filter_type, output_type](
        input_dev_tensor,
        filter_nchw_dev_tensor,
        output_dev_tensor,
        stride_dim,
        dilation_dim,
        pad_dim,
        num_groups,
        ctx,
    )

    # Copy both results back to the host, then synchronize before reading
    # them (the enqueue_* calls are presumably asynchronous — confirm).
    ctx.enqueue_copy(output_ref_host_ptr, output_ref_dev)
    ctx.enqueue_copy(output_host_ptr, output_dev)
    ctx.synchronize()

    # Verify results
    # Only `rtol` is passed; the absolute tolerance is whatever
    # `assert_almost_equal` defaults to.
    for n in range(Nout):
        for h in range(Hout):
            for w in range(Wout):
                for f in range(Cout):
                    assert_almost_equal(
                        output_host[n, h, w, f],
                        output_ref_host[n, h, w, f],
                        rtol=0.01,
                    )
    print("Succeed")

    # Cleanup host memory
    # NOTE(review): if a comparison above raises, these frees are not reached.
    input_host_ptr.free()
    filter_host_ptr.free()
    filter_nchw_host_ptr.free()
    output_ref_host_ptr.free()
    output_host_ptr.free()

    # Cleanup device buffers
    # NOTE(review): presumably these transfers also keep each buffer alive up
    # to this point, since the device tensors above only hold raw pointers
    # obtained from `unsafe_ptr()` — confirm before moving or removing them.
    _ = input_dev^
    _ = filter_dev^
    _ = filter_nchw_dev^
    _ = output_dev^
    _ = output_ref_dev^


def main():
    # Entry point: runs the cuDNN-vs-reference convolution check on one large
    # 1x7 problem, on a small 3x3 problem for several dtypes, and then again
    # on freshly created device contexts.

    # Small 3x3 problem shared by the dtype sweep and the per-context runs.
    comptime small_input = IndexList[4](1, 8, 8, 16)  # NHWC
    comptime small_filter = IndexList[4](3, 3, 16, 32)  # RSCF
    comptime small_output = IndexList[4](1, 6, 6, 32)  # NHWC
    comptime unit_stride = IndexList[2](1, 1)
    comptime unit_dilation = IndexList[2](1, 1)
    comptime zero_pad = IndexList[2](0, 0)

    # Element types covered by the dtype sweep.
    comptime element_types = (DType.float32, DType.float16, DType.bfloat16)

    with DeviceContext() as ctx:
        # Large float32 case: 1x7 filter over a 1x550 input, padded by 3
        # along the width.
        test_conv_cudnn[
            IndexList[4](1, 1, 550, 1024),  # input  (NHWC)
            IndexList[4](
                1, 7, 1024, 1024
            ),  # filter (RSCF): height, width, in_channels, out_channels
            IndexList[4](1, 1, 550, 1024),  # output (NHWC)
            DType.float32,
            DType.float32,
            DType.float32,
            IndexList[2](1, 1),  # stride
            IndexList[2](1, 1),  # dilation
            IndexList[2](0, 3),  # pad
        ](ctx)

        # Same small problem once per element type; input, filter and output
        # all use that type.
        @parameter
        for idx in range(len(element_types)):
            comptime elem = element_types[idx]

            test_conv_cudnn[
                small_input,
                small_filter,
                small_output,
                elem,
                elem,
                elem,
                unit_stride,
                unit_dilation,
                zero_pad,
            ](ctx)

        # Grouped convolutions are intentionally not exercised here.
        # NOTE: conv_gpu's naive implementation does not support grouped
        # convolutions, so it cannot serve as the reference for them.
        #
        # TODO(KERN-1846): Add grouped convolution tests once a proper
        # reference implementation exists. Planned configurations:
        # - Standard conv with num_groups=1
        # - 2 groups
        # - 4 groups
        # - Depthwise convolution (num_groups = in_channels)
        # - Grouped convolution with float16

    # Repeat the small float32 case on separately created contexts, one after
    # another.
    print("\n== Testing with multiple device contexts ==")

    # Default device first.
    print("Creating first device context (default device)...")
    with DeviceContext() as default_ctx:
        test_conv_cudnn[
            small_input,
            small_filter,
            small_output,
            DType.float32,
            DType.float32,
            DType.float32,
            unit_stride,
            unit_dilation,
            zero_pad,
        ](default_ctx)

    # The remaining run targets device 1 and needs at least two devices.
    if DeviceContext.number_of_devices() < 2:
        return

    print("Creating second device context (device 1)...")
    with DeviceContext(device_id=1) as second_ctx:
        test_conv_cudnn[
            small_input,
            small_filter,
            small_output,
            DType.float32,
            DType.float32,
            DType.float32,
            unit_stride,
            unit_dilation,
            zero_pad,
        ](second_ctx)

    print("Multiple device context test completed successfully!")
